{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Copyright statement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " # Copyright (c) 2023, salesforce.com, inc.\n",
    " # All rights reserved.\n",
    " # SPDX-License-Identifier: BSD-3-Clause\n",
    " # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# X-InstructBLIP Demo\n",
    "\n",
    "Before proceeding **download the Vicuna v1.1 model weights** following the instructions [here](https://github.com/lm-sys/FastChat). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM_MODEL_PATH = \"<add/llm/path/here>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## SETUP THE ENVIRONMENT\n",
    "!git clone https://github.com/artemisp/LAVIS-XInstructBLIP.git\n",
    "!cd LAVIS-XInstructBLIP && python -m pip install -e .\n",
    "!python -m pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl\n",
    "!wget -P /usr/bin https://github.com/unlimblue/KNN_CUDA/raw/master/ninja"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import torch\n",
    "import argparse\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from lavis.common.registry import registry\n",
    "import random\n",
    "\n",
    "import trimesh\n",
    "import pyvista as pv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prelim\n",
    "Set up seeds for reproducibility and identify device type. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_seeds(seed=42):\n",
    "    seed = seed\n",
    "\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3D file to point cloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Raw Files\n",
    "\n",
    "#https://github.com/mikedh/trimesh/issues/507\n",
    "def as_mesh(scene_or_mesh):\n",
    "    \"\"\"\n",
    "    Convert a possible scene to a mesh.\n",
    "\n",
    "    If conversion occurs, the returned mesh has only vertex and face data.\n",
    "    \"\"\"\n",
    "    if isinstance(scene_or_mesh, trimesh.Scene):\n",
    "        if len(scene_or_mesh.geometry) == 0:\n",
    "            mesh = None  # empty scene\n",
    "        else:\n",
    "            # we lose texture information here\n",
    "            mesh = trimesh.util.concatenate(\n",
    "                tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)\n",
    "                    for g in scene_or_mesh.geometry.values()))\n",
    "    else:\n",
    "        assert(isinstance(scene_or_mesh, trimesh.Trimesh))\n",
    "        mesh = scene_or_mesh\n",
    "    return mesh\n",
    "\n",
    "def convert_mesh_to_numpy(mesh_file, npoints=8192):\n",
    "    print(\"Loading point cloud.\")\n",
    "    # Load the mesh using trimesh\n",
    "    mesh = trimesh.load_mesh(mesh_file, force='mesh')\n",
    "    mesh = as_mesh(mesh)\n",
    "\n",
    "    # Subsample or upsample the mesh to have exactly npoints points\n",
    "    vertices = mesh.vertices\n",
    "    num_points = len(vertices)\n",
    "    if num_points < npoints:\n",
    "        # Upsample the mesh by repeating vertices\n",
    "        repetitions = int(np.ceil(npoints / num_points))\n",
    "        vertices = np.repeat(vertices, repetitions, axis=0)[:npoints]\n",
    "    elif num_points > npoints:\n",
    "        # Subsample the mesh to the desired number of points\n",
    "        # indices = trimesh.sample.sample_surface(mesh, npoints)[0]\n",
    "        vertices = mesh.vertices#[indices]\n",
    "    print(\"Point cloud loaded..\")\n",
    "    \n",
    "    return vertices\n",
    "\n",
    "\n",
    "\n",
    "def load_mesh(mesh_file_name):\n",
    "    return mesh_file_name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Preprocessors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load Preprocessors\n",
    "from lavis.processors.ulip_processors import ULIPPCProcessor\n",
    "from lavis.processors.clip_processors import ClipImageEvalProcessor\n",
    "from lavis.processors.audio_processors import BeatsAudioProcessor\n",
    "from lavis.processors.alpro_processors import AlproVideoEvalProcessor\n",
    "\n",
    "pc_pocessor = ULIPPCProcessor()\n",
    "image_pocessor = ClipImageEvalProcessor()\n",
    "audio_processor = BeatsAudioProcessor(model_name='iter3', sampling_rate=16000, n_frames=2, is_eval=False, frame_length=512)\n",
    "video_processor = AlproVideoEvalProcessor(n_frms=4, image_size=224)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model from LAVIS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lavis.models.blip2_models.blip2_vicuna_xinstruct import Blip2VicunaXInstruct\n",
    "model = \"vicuna7b_v2\"\n",
    "cfg_path  = {\n",
    "        \"vicuna13b\": './configs/vicuna13b.yaml',\n",
    "        \"vicuna7b\": './configs/vicuna7b.yaml',\n",
    "        \"no_init\": './configs/vicuna7b_no_init.yaml',\n",
    "        \"projection\": './configs/vicuna7b_projection.yaml'\n",
    "        \"vicuna7b_v2\": './configs/vicuna7b_v2.yaml'\n",
    "    }\n",
    "    \n",
    "config = OmegaConf.load(cfg_path[args.model])\n",
    "config.get(\"model\", None).llm_model = LLM_MODEL_PATH\n",
    "print(cfg_path[args.model])\n",
    "print('Loading model...')\n",
    "model_cls = registry.get_model_class(config.get(\"model\", None).arch)\n",
    "model =  model_cls.from_config(config.get(\"model\", None))\n",
    "model.to(device)\n",
    "print('Loading model done!')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(image, point_cloud, audio, video, prompt, qformer_prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):\n",
    "    if qformer_prompt == \"\" or qformer_prompt == None:\n",
    "        qformer_prompt = prompt\n",
    "    use_nucleus_sampling = decoding_method == \"Nucleus sampling\"\n",
    "    print(image, point_cloud, audio, video, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, use_nucleus_sampling)\n",
    "    if image is not None:\n",
    "        image = image_pocessor(image).unsqueeze(0).to(device)\n",
    "    if point_cloud is not None:\n",
    "        point_cloud = convert_mesh_to_numpy(point_cloud)\n",
    "        point_cloud = pc_pocessor(point_cloud).unsqueeze(0).to(device)\n",
    "    if audio is not None:\n",
    "        audio = audio_processor(audio).unsqueeze(0).to(device)\n",
    "    if video is not None:\n",
    "        video = video_processor(video).unsqueeze(0).to(device)\n",
    "    \n",
    "    samples = {\"prompt\": prompt}\n",
    "    if image is not None:\n",
    "        samples[\"image\"] = image\n",
    "    if point_cloud is not None:\n",
    "        samples[\"pc\"] = point_cloud\n",
    "    if audio is not None:\n",
    "        samples[\"audio\"] = audio\n",
    "    if video is not None:\n",
    "        samples[\"video\"] = video\n",
    "\n",
    "    output = model.generate(\n",
    "        samples,\n",
    "        length_penalty=float(len_penalty),\n",
    "        repetition_penalty=float(repetition_penalty),\n",
    "        num_beams=beam_size,\n",
    "        max_length=max_len,\n",
    "        min_length=min_len,\n",
    "        top_p=top_p,\n",
    "        use_nucleus_sampling=use_nucleus_sampling,\n",
    "        special_qformer_input_prompt=qformer_prompt\n",
    "    )\n",
    "\n",
    "    return output[0]\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "setup_seeds()\n",
    "\n",
    "image_input = gr.Image(type=\"pil\")\n",
    "\n",
    "pc_input = gr.Model3D()\n",
    "\n",
    "audio_input = gr.Audio(sources=[\"upload\"], type=\"filepath\")\n",
    "\n",
    "video_input = gr.Video()\n",
    "\n",
    "min_len = gr.Slider(\n",
    "    minimum=1,\n",
    "    maximum=50,\n",
    "    value=1,\n",
    "    step=1,\n",
    "    interactive=True,\n",
    "    label=\"Min Length\",\n",
    ")\n",
    "\n",
    "max_len = gr.Slider(\n",
    "    minimum=10,\n",
    "    maximum=500,\n",
    "    value=250,\n",
    "    step=5,\n",
    "    interactive=True,\n",
    "    label=\"Max Length\",\n",
    ")\n",
    "\n",
    "sampling = gr.Radio(\n",
    "    choices=[\"Beam search\", \"Nucleus sampling\"],\n",
    "    value=\"Beam search\",\n",
    "    label=\"Text Decoding Method\",\n",
    "    interactive=True,\n",
    ")\n",
    "\n",
    "top_p = gr.Slider(\n",
    "    minimum=0.5,\n",
    "    maximum=1.0,\n",
    "    value=0.9,\n",
    "    step=0.1,\n",
    "    interactive=True,\n",
    "    label=\"Top p\",\n",
    ")\n",
    "\n",
    "beam_size = gr.Slider(\n",
    "    minimum=1,\n",
    "    maximum=10,\n",
    "    value=5,\n",
    "    step=1,\n",
    "    interactive=True,\n",
    "    label=\"Beam Size\",\n",
    ")\n",
    "\n",
    "len_penalty = gr.Slider(\n",
    "    minimum=-1,\n",
    "    maximum=2,\n",
    "    value=1,\n",
    "    step=0.2,\n",
    "    interactive=True,\n",
    "    label=\"Length Penalty\",\n",
    ")\n",
    "\n",
    "repetition_penalty = gr.Slider(\n",
    "    minimum=0.,\n",
    "    maximum=3,\n",
    "    value=1.5,\n",
    "    step=0.2,\n",
    "    interactive=True,\n",
    "    label=\"Repetition Penalty\",\n",
    ")\n",
    "\n",
    "\n",
    "prompt_textbox = gr.Textbox(label=\"Prompt:\", placeholder=\"prompt\", lines=2)\n",
    "qformer_prompt_textbox = gr.Textbox(label=\"Qformer Prompt:\", placeholder=\"prompt\", lines=2)\n",
    "\n",
    "iface = gr.Interface(\n",
    "    fn=inference,\n",
    "    inputs=[image_input, pc_input, audio_input, video_input, prompt_textbox, qformer_prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling],\n",
    "    outputs=\"text\",\n",
    "    allow_flagging=\"never\",\n",
    "    examples=examples\n",
    ")\n",
    "\n",
    "iface.launch(share=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
